from typing import List, Dict, Any, Callable
import copy
import numpy as np


class GreedyImportantFirstStrategy:
    """
    Greedy "important parameters first" hyperparameter search.

    Each round:
    - Rank all hyperparameters by importance and split them into groups of
      ``top_k`` parameters (the last group may have fewer than ``top_k``).
    - Split ``step_trials`` across the groups in proportion to each group's
      summed importance (every group gets at least one trial).
    - Optimize each group jointly while every parameter outside the group is
      locked to the current best configuration.
    - Pass the entire history to every optimizer so it can warm start.
    - Recompute importance at the start of the next round.
    - If the round produced no improvement and the full-space quota
      (``full_group_ratio * max_total_trials``) is not used up, run one extra
      optimization over all parameters at once.

    The search stops once ``max_total_trials`` trials (initial trials
    included) have been recorded. It also stops early if a whole round adds
    no trials, i.e. every optimizer failed or returned nothing.
    """

    def __init__(
        self,
        search_space: Dict[str, Any],
        initial_trials: List[Dict[str, Any]],
        optimizer_builder: Callable[
            [Dict[str, Any], List[Dict[str, Any]], Dict[str, Any], int], Any
        ],
        importance_evaluator: Callable[
            [List[Dict[str, Any]], List[float]], Dict[str, float]
        ],
        step_trials: int = 20,
        max_total_trials: int = 100,
        min_trials_for_importance: int = 10,
        default_config: Dict[str, Any] = None,
        full_group_ratio: float = 0.1,
        top_k: int = 2,
    ):
        """
        Args:
            search_space: Mapping from parameter name to its search-space
                definition. Values are passed to ``optimizer_builder``
                unchanged.
            initial_trials: Warm-start trials, each a dict with at least
                ``"config"`` and ``"score"`` keys. The list is deep-copied,
                so the caller's list is never mutated. May be empty.
            optimizer_builder: Called as
                ``(subspace, history, fixed_config, max_trials)``. Must return
                an object whose ``optimize()`` yields
                ``(config, score, elapsed_time)`` tuples, where ``config``
                holds values for the subspace parameters.
            importance_evaluator: Called as ``(configs, scores)``. Returns a
                mapping from parameter name to importance; higher values get
                optimized earlier and receive more trials.
            step_trials: Trial budget per round, split across groups.
            max_total_trials: Total trial budget, including ``initial_trials``.
            min_trials_for_importance: While the history is smaller than
                this, every parameter gets equal importance and
                ``importance_evaluator`` is not called.
            default_config: Starting configuration, used only when
                ``initial_trials`` is empty.
            full_group_ratio: Fraction of ``max_total_trials`` that may be
                spent on full-space (all parameters at once) optimization.
            top_k: Number of parameters per group.
        """
        self.search_space = search_space
        self.history = copy.deepcopy(initial_trials)
        self.optimizer_builder = optimizer_builder
        self.importance_evaluator = importance_evaluator
        self.step_trials = step_trials
        self.max_total_trials = max_total_trials
        self.min_trials_for_importance = min_trials_for_importance
        self.default_config = default_config or {}
        self.full_group_ratio = full_group_ratio
        # Always allow at least one full-space trial.
        self.max_full_trials = max(
            1, int(self.full_group_ratio * self.max_total_trials)
        )
        self.full_trials_used = 0
        self.logs = []
        self.top_k = top_k

    def _compute_importance(self) -> Dict[str, float]:
        """Return an importance score for every search-space parameter."""
        if len(self.history) < self.min_trials_for_importance:
            # Not enough data for a meaningful estimate: treat all equally.
            return {k: 1.0 for k in self.search_space}
        configs = [t["config"] for t in self.history]
        scores = [t["score"] for t in self.history]
        raw = self.importance_evaluator(configs, scores)
        # Keep the evaluator's key order, because it breaks ties in the
        # stable sort. Drop names outside the search space, which would
        # otherwise raise KeyError later. Parameters the evaluator omitted
        # get zero importance so they are still optimized.
        importance = {k: raw[k] for k in raw if k in self.search_space}
        for k in self.search_space:
            importance.setdefault(k, 0.0)
        return importance

    def _allocate_trials(self, group_weights: List[float]) -> List[int]:
        """Split ``step_trials`` across groups in proportion to their weights.

        Every group receives at least one trial. The total equals
        ``step_trials`` whenever that is achievable under that floor.
        """
        if not group_weights:
            return []
        # A negative trial share is meaningless; treat negative weights as 0.
        weights = [max(float(w), 0.0) for w in group_weights]
        total_weight = sum(weights)
        if total_weight <= 0:
            # All-zero importance: fall back to an even split.
            weights = [1.0] * len(weights)
            total_weight = float(len(weights))
        trials = [
            max(1, int(round(self.step_trials * w / total_weight)))
            for w in weights
        ]
        allocated = sum(trials)
        if allocated < self.step_trials:
            # Give the rounding shortfall to the heaviest (first on ties) group.
            max_idx = max(range(len(weights)), key=weights.__getitem__)
            trials[max_idx] += self.step_trials - allocated
        else:
            # Rounding up and the one-trial floor can overshoot. Trim from
            # the lightest groups (the later one on ties), never below one.
            while allocated > self.step_trials:
                candidates = [i for i, t in enumerate(trials) if t > 1]
                if not candidates:
                    break
                idx = min(candidates, key=lambda i: (weights[i], -i))
                trials[idx] -= 1
                allocated -= 1
        return trials

    def _run_optimizer(
        self,
        subspace: Dict[str, Any],
        fixed_config: Dict[str, Any],
        budget: int,
        base_config: Dict[str, Any],
        round_idx: int,
        group: List[str],
        failure_prefix: str,
    ) -> List[Dict[str, Any]]:
        """Build and run one optimizer; return its results as trial dicts.

        Each result's config is merged over ``base_config`` so every trial
        carries a complete configuration. Returns ``[]`` if ``optimize()``
        raises; the error is printed after ``failure_prefix``.
        """
        optimizer = self.optimizer_builder(
            subspace, self.history, fixed_config, budget
        )
        try:
            optimizer_results = optimizer.optimize()
        except Exception as e:
            print(f"{failure_prefix}: {e}")
            return []

        new_trials = []
        for config, result, elapsed_time in optimizer_results:
            full_config = dict(base_config)
            full_config.update(config)
            new_trials.append(
                {
                    "config": full_config,
                    "score": result,
                    "elapsed_time": elapsed_time,
                    "round": round_idx,
                    "group": group,
                }
            )
        return new_trials

    def run(self) -> List[Dict[str, Any]]:
        """Run the search until the trial budget is spent.

        Returns:
            The full trial history: the initial trials followed by every trial
            produced here. Each new trial is a dict with keys ``config``,
            ``score``, ``elapsed_time``, ``round`` and ``group``.

        Side effects:
            Extends ``self.history`` and ``self.logs``. Sets
            ``self.current_best_config``, ``self.current_best_score`` and
            ``self.full_trials_used``.
        """
        total_trials_used = len(self.history)
        param_names = list(self.search_space.keys())

        if self.history:
            best_trial = max(self.history, key=lambda t: t["score"])
            current_best_config = dict(best_trial["config"])
            current_best_score = best_trial["score"]
        else:
            # No warm-start data: start from the defaults, so the first
            # trial produced always counts as an improvement.
            current_best_config = dict(self.default_config)
            current_best_score = float("-inf")

        self.full_trials_used = 0
        round_idx = 0

        while total_trials_used < self.max_total_trials:
            trials_at_round_start = total_trials_used

            # 1. Importance, then groups of top_k in descending importance.
            importance = self._compute_importance()
            sorted_param_names = [
                name
                for name, _ in sorted(importance.items(), key=lambda x: -x[1])
            ]
            groups = [
                sorted_param_names[i : i + self.top_k]
                for i in range(0, len(sorted_param_names), self.top_k)
            ]

            # 2. Per-group trial budgets proportional to group importance.
            group_weights = [
                sum(importance[p] for p in group) for group in groups
            ]
            group_trials = self._allocate_trials(group_weights)

            # 3. Optimize each group with everything else locked to the best.
            round_has_improvement = False
            remaining_trials = self.max_total_trials - total_trials_used
            for group, budget in zip(groups, group_trials):
                allowed = min(budget, remaining_trials)
                if allowed <= 0:
                    break

                subspace = {k: self.search_space[k] for k in group}
                # Only lock parameters the best config actually has values for
                # (default_config may be partial).
                fixed_config = {
                    k: current_best_config[k]
                    for k in self.search_space
                    if k not in group and k in current_best_config
                }
                new_trials = self._run_optimizer(
                    subspace,
                    fixed_config,
                    allowed,
                    current_best_config,
                    round_idx,
                    group,
                    f"[GIF] Optimizer failed on group {group}",
                )
                if not new_trials:
                    continue

                self.history.extend(new_trials)
                total_trials_used += len(new_trials)
                remaining_trials -= len(new_trials)

                # Check improvement BEFORE any budget-exhausted break, so the
                # final batch's best result is never lost.
                group_best = max(new_trials, key=lambda t: t["score"])
                if group_best["score"] > current_best_score:
                    current_best_score = group_best["score"]
                    current_best_config = dict(group_best["config"])
                    round_has_improvement = True

                self.logs.append(
                    {
                        "round": round_idx,
                        "group": group,
                        "trials": len(new_trials),
                        "best_score": group_best["score"],
                        "importance_snapshot": dict(importance),
                    }
                )
                if remaining_trials <= 0:
                    break  # No more trials left to allocate

            # 4. No improvement this round: spend some full-space quota.
            full_quota_left = self.max_full_trials - self.full_trials_used
            remaining_trials = self.max_total_trials - total_trials_used
            if remaining_trials > 0:
                # Spread the remaining quota over the (approximate) number of
                # rounds still to come.
                rounds_left = remaining_trials // max(self.step_trials, 1) + 1
                full_budget = min(
                    int(full_quota_left / rounds_left), remaining_trials
                )
            else:
                full_budget = 0

            if (
                not round_has_improvement
                and full_quota_left > 0
                and full_budget > 0
            ):
                group = param_names
                subspace = {k: self.search_space[k] for k in group}
                new_trials = self._run_optimizer(
                    subspace,
                    {},
                    full_budget,
                    current_best_config,
                    round_idx,
                    group,
                    "[GIF] Full optimizer failed",
                )
                if new_trials:
                    n_added = len(new_trials)
                    self.history.extend(new_trials)
                    total_trials_used += n_added
                    self.full_trials_used += n_added

                    group_best = max(new_trials, key=lambda t: t["score"])
                    if group_best["score"] > current_best_score:
                        current_best_score = group_best["score"]
                        current_best_config = dict(group_best["config"])

                    self.logs.append(
                        {
                            "round": round_idx,
                            "group": group,
                            "trials": n_added,
                            "best_score": group_best["score"],
                            "importance_snapshot": dict(importance),
                            "full_group": True,
                        }
                    )

            # Guard against an endless loop: if every optimizer failed or
            # returned nothing, the next round would see the same history
            # and most likely repeat the same outcome.
            if total_trials_used == trials_at_round_start:
                print(f"[GIF] Round {round_idx} produced no trials; stopping early.")
                break
            round_idx += 1

        self.current_best_config = current_best_config
        self.current_best_score = current_best_score
        return self.history
